#include "statsxx/machine_learning/NeuralNet/Neuron.hpp"

// STL
#include <cstdlib>  // std::exit, EXIT_FAILURE
#include <iostream>

// stats++
#include "statsxx/machine_learning/activation_functions.hpp" // activation_function::Logistic, activation_function::softplus


// Default constructor: delegates to the main constructor so the two can
// never drift apart — yields a de-activated HIDDEN neuron with the
// IDENTITY activation function and ID 0.
inline NEURON::NEURON()
    : NEURON(0, NEURON::TYPE::HIDDEN, NEURON::ACTIVATION_FUNC::IDENTITY)
{
}

// Constructs a de-activated neuron with the given ID, type, and
// activation function. Member-initializer list replaces the original
// assign-in-body pattern (same values, idiomatic initialization).
inline NEURON::NEURON(int ID, NEURON::TYPE tp, NEURON::ACTIVATION_FUNC af)
    : m_isActivated(false),
      type(tp),
      act_func(af),
      m_ID(ID)
{
}

inline NEURON::~NEURON() {};


//========================================================================
//========================================================================
//
// NAME: void NEURON::clear_memory()
//
// DESC: Discards the neuron's entire stored output history (relevant
//       for recurrent neurons, which read earlier outputs back via
//       get_output()). The activation flag and error history are left
//       untouched.
//
//========================================================================
//========================================================================
inline void NEURON::clear_memory()
{
    m_output.clear();
}


//========================================================================
//========================================================================
//
// NAME: void NEURON::clear_errors()
//
// DESC: Discards the neuron's entire stored error history (the deltas
//       accumulated during backpropagation). The output history and
//       activation flag are left untouched.
//
//========================================================================
//========================================================================
inline void NEURON::clear_errors()
{
    m_errors.clear();
}


//========================================================================
//========================================================================
//
// NAME: void NEURON::deactivate()
//
// DESC: Clears the activation flag so the neuron reports itself as not
//       yet activated. The output/error histories (its "memory") are
//       preserved; use clear_memory()/clear_errors() to discard those.
//
//========================================================================
//========================================================================
inline void NEURON::deactivate()
{
    m_isActivated = false;
}


//========================================================================
//========================================================================
//
// NAME: void NEURON::activate()
//
// DESC: Applies the neuron's activation function in place to the most
//       recent output value (m_output.front()), then marks the neuron
//       as activated. For IDENTITY the value is left as-is. SOFTMAX
//       cannot be evaluated locally (it requires the other neurons in
//       the layer), so the value is passed through unchanged with a
//       warning; the caller is expected to fix it up afterwards via
//       override_output().
//
// NOTE: Each activation functor is now constructed only inside the
//       case that uses it, instead of all three being constructed on
//       every call.
//
//========================================================================
//========================================================================
inline void NEURON::activate()
{
    switch(act_func)
    {
        case NEURON::ACTIVATION_FUNC::IDENTITY:
            // F(x) = x: nothing to do.
            break;
        case NEURON::ACTIVATION_FUNC::LOGISTIC:
        {
            activation_function::Logistic logistic;
            m_output.front() = logistic.f(m_output.front());
            break;
        }
        case NEURON::ACTIVATION_FUNC::TANH:
        {
            activation_function::tanh_skewed tanh_skewed;
            m_output.front() = tanh_skewed.f(m_output.front());
            break;
        }
        case NEURON::ACTIVATION_FUNC::SOFTPLUS:
        {
            activation_function::softplus softplus;
            m_output.front() = softplus.f(m_output.front());
            break;
        }
        case NEURON::ACTIVATION_FUNC::SOFTMAX:
            // Softmax needs the whole layer's values; leave the raw
            // input in place and warn.
            std::cout << "Warning in NEURON::activate(): Cannot calculate softmax function without knowledge of other neurons! Setting output to input." << std::endl;
            break;
        default:
            break;
    }

    m_isActivated = true;
}


//========================================================================
//========================================================================
//
// NAME: void NEURON::activate_error()
//
// DESC: Finishes the error (delta) at this neuron. On entry,
//       m_errors.back() holds the raw back-propagated sum
//       sum_{k in K}{delta_k * W_jk}; this routine multiplies it in
//       place by F'(x_j) and marks the neuron as activated.
//
// NOTES:
//     ! the error for a neuron (except input, output, or bias) is the
//       sum of the products between the errors of the neurons in the
//       next layer and the weights of the connections to those neurons,
//       multiplied by the derivative of the activation function, so:
//          delta_j = (sum_{k in K}{delta_k*W_jk})*F'(x_j)
//               x_j     == value of neuron at node j
//               delta_k == error of downstream node k (NOT necessarily
//                          output layer)
//               W_jk    == weight connecting node j to node k
//     ! only outputs F(x) are stored, so the input x is recovered via
//       the inverse activation function before differentiating.
//
//========================================================================
//========================================================================
inline void NEURON::activate_error()
{
    double dFx = 0.0;

    // NOTE(review): m_output is indexed with m_errors.size()-1 below —
    // presumably pairing this error with the output of the same
    // timestep during backpropagation through time; confirm against the
    // front-indexed convention used by activate()/get_output().
    switch(act_func)
    {
        case NEURON::ACTIVATION_FUNC::IDENTITY:
            dFx = 1.0;
            break;
        case NEURON::ACTIVATION_FUNC::LOGISTIC:
        {
            activation_function::Logistic logistic;
            dFx = logistic.df(logistic.inv(m_output[m_errors.size()-1]));
            break;
        }
        case NEURON::ACTIVATION_FUNC::TANH:
        {
            activation_function::tanh_skewed tanh_skewed;
            dFx = tanh_skewed.df( tanh_skewed.inv( m_output[m_errors.size()-1] ) );
            break;
        }
        case NEURON::ACTIVATION_FUNC::SOFTPLUS:
        {
            activation_function::softplus softplus;
            dFx = softplus.df(softplus.inv(m_output[m_errors.size()-1]));
            break;
        }
        case NEURON::ACTIVATION_FUNC::SOFTMAX:
            std::cout << "Error in NEURON::activate_error(): Unsure how to calculate dsoftmax!" << std::endl;
            // Fatal error: exit with a failure status (the original
            // exit(0) signalled success to the shell).
            std::exit(EXIT_FAILURE);
            break;
        default:
            break;
    }

    m_errors.back() *= dFx;

    m_isActivated   = true;
}


//========================================================================
//========================================================================
//
// NAME: bool NEURON::is_activated()
//
// DESC: Returns whether the neuron is currently activated (set by
//       activate()/activate_error(), cleared by deactivate()). Note
//       that a de-activated neuron can still hold a memory of past
//       outputs and errors.
//
//========================================================================
//========================================================================
inline bool NEURON::is_activated()
{
    return m_isActivated;
}


//========================================================================
//========================================================================
//
// NAME: double NEURON::get_output(unsigned int tdelay)
//
// DESC: Returns the neuron output at time t - tdelay (index 0 being
//       the most recent value, matching activate()'s use of
//       m_output.front()). A neuron can be ``forgetful'' or have a
//       ``bad memory'': if no value exists that far back, 0.0 is
//       returned.
//
//========================================================================
//========================================================================
inline double NEURON::get_output(unsigned int tdelay)
{
    // tdelay >= size() instead of (tdelay + 1) > size(): the original
    // form wraps to 0 when tdelay == UINT_MAX, which would pass the
    // guard and index out of bounds.
    if( tdelay >= m_output.size() )
    {
        return 0.0;
    }
    else
    {
        return m_output[tdelay];
    }
}


//========================================================================
//========================================================================
//
// NAME: double NEURON::get_error(unsigned int tdelay)
//
// DESC: Returns the neuron error at time t - tdelay. A neuron can be
//       ``forgetful'' or have a ``bad memory'': if no value exists that
//       far back, 0.0 is returned.
//
// NOTE(review): errors are indexed from the BACK (newest == back(),
//       matching activate_error()), whereas get_output() indexes
//       outputs from the front — confirm the two conventions are
//       intentional.
//
//========================================================================
//========================================================================
inline double NEURON::get_error(unsigned int tdelay)
{
    // tdelay >= size() instead of (tdelay + 1) > size(): the original
    // form wraps to 0 when tdelay == UINT_MAX, which would pass the
    // guard and index out of bounds.
    if( tdelay >= m_errors.size() )
    {
        return 0.0;
    }
    else
    {
        return m_errors[m_errors.size()-1 - tdelay];
    }
}


//========================================================================
//========================================================================
//
// NAME: void NEURON::override_activation(bool flag)
//
// DESC: Forces the value of m_isActivated, which is normally only set
//       by calling activate()/activate_error() or cleared by
//       deactivate(). Useful when the network, not the neuron itself,
//       decides the activation state (e.g. input/bias neurons).
//
//========================================================================
//========================================================================
inline void NEURON::override_activation(bool flag)
{
    m_isActivated = flag;
}


//========================================================================
//========================================================================
//
// NAME: void NEURON::override_output(unsigned int tdelay, double val)
//
// DESC: Overrides the stored output at a delay of tdelay. This is
//       important when the combined output of a collection of neurons
//       follows some function (e.g., softmax in the case of 1-of-n
//       classification). If tdelay lies beyond the current history,
//       m_output is grown to tdelay+1 entries with 0s filled in-between
//       (i.e., it will appear to have ``forgotten'' those values).
//
//========================================================================
//========================================================================
inline void NEURON::override_output(unsigned int tdelay, double val)
{
    // Grow — never shrink — the history. The original unconditional
    // resize() truncated remembered outputs whenever the history was
    // already longer than tdelay+1, contradicting the documented
    // behavior. The size_t cast also avoids unsigned wrap-around at
    // tdelay == UINT_MAX.
    if( m_output.size() <= tdelay )
    {
        m_output.resize( static_cast<std::size_t>(tdelay) + 1, 0.0 );
    }

    m_output[tdelay] = val;
}
